/*
libndi
Copyright (C) 2020 VideoLAN

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
*/

#define _POSIX_C_SOURCE 200809L
#include <errno.h>
#include <inttypes.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <sys/types.h>
#ifdef _WIN32
#include <ws2tcpip.h>
#include <winsock2.h>
#else
#include <sys/socket.h>
#include <netdb.h>
#include <unistd.h>
#endif

#ifdef HAVE_POLL
#include <poll.h>
#else
#include "compat/poll.c"
#endif

#include <libavutil/common.h>
#include <libavutil/frame.h>
#include <libavutil/fifo.h>

#include "libndi.h"

/*
 * XOR mask used by ndi_unscramble_type2(): byte i of the buffer being
 * unscrambled is XORed with entry i of this table, for at most the first
 * 128 bytes (the number of entries below).
 */
uint8_t ndi_xortab[] = {
  0x4e, 0x44, 0x49, 0xae, 0x2c, 0x20, 0xa9, 0x32, 0x30, 0x31, 0x37, 0x20,
  0x4e, 0x65, 0x77, 0x54, 0x65, 0x6b, 0x2c, 0x20, 0x50, 0x72, 0x6f, 0x70,
  0x72, 0x69, 0x65, 0x74, 0x79, 0x20, 0x61, 0x6e, 0x64, 0x20, 0x43, 0x6f,
  0x6e, 0x66, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x2e, 0x20,
  0x59, 0x6f, 0x75, 0x20, 0x61, 0x72, 0x65, 0x20, 0x69, 0x6e, 0x20, 0x76,
  0x69, 0x6f, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x6f, 0x66, 0x20,
  0x74, 0x68, 0x65, 0x20, 0x4e, 0x44, 0x49, 0xae, 0x20, 0x53, 0x44, 0x4b,
  0x20, 0x6c, 0x69, 0x63, 0x65, 0x6e, 0x73, 0x65, 0x20, 0x61, 0x74, 0x20,
  0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x6e, 0x65, 0x77, 0x2e, 0x74,
  0x6b, 0x2f, 0x6e, 0x64, 0x69, 0x73, 0x64, 0x6b, 0x5f, 0x6c, 0x69, 0x63,
  0x65, 0x6e, 0x73, 0x65, 0x2f, 0x00, 0x00, 0x00
};

/* One serialized request message queued for transmission on the socket. */
typedef struct ndi_message
{
    uint8_t *buf; /* heap-allocated message bytes; allocated in libndi_setup(),
                     freed by request_ndi_data() or libndi_close() */
    int len;      /* number of bytes of buf handed to send() */
} ndi_message;

/* Per-connection receiver state, created by libndi_init(). */
struct ndi_ctx
{
    /* buffers */
    AVFifoBuffer *fifo; /* bytes received from the socket, not yet consumed */

    /* Total size of the packet currently being awaited in the FIFO;
     * 0 means no packet header has been parsed yet. */
    int target_size;

    int socket_fd; /* socket opened in libndi_receive_data() */

    ndi_message ndi_request[4]; /* request messages built by libndi_setup() */
    int pending_requests;       /* how many of ndi_request[] are still to be sent */

    ndi_data_cb callback; /* invoked for each decoded video/audio message */
    void *user_data;      /* passed unchanged as the callback's second argument */

    /* options */
    char *ip;   /* strdup() copy of ndi_opts->ip */
    char *port; /* strdup() copy of ndi_opts->port */
};

/* Probably could merge scramble and unscramble */
/*
 * Scramble len bytes of buf in place with the "type 1" scheme.
 *
 * The buffer is processed as a sequence of 64-bit words followed by an
 * optional tail of 1..7 bytes.  The keystream state is fed back with the
 * *unscrambled* value of each word, so ndi_unscramble_type1() can follow
 * the same state while decoding.
 *
 * Words are loaded and stored with memcpy() rather than through a
 * uint64_t pointer: callers pass buffers at arbitrary offsets (e.g.
 * message + 12), so a direct 64-bit access could be misaligned and would
 * break the aliasing rules.
 *
 * buf  - data to scramble (modified in place)
 * len  - number of bytes in buf; values <= 0 leave buf untouched
 * seed - 32-bit seed, replicated into both halves of the 64-bit state
 */
static void ndi_scramble_type1(uint8_t *buf, int len, uint32_t seed)
{
    uint64_t seed64 = ((uint64_t)seed << 32) | seed;
    uint64_t seed1  = seed64 ^ 0xb711674bd24f4b24ULL;
    uint64_t seed2  = seed64 ^ 0xb080d84f1fe3bf44ULL;

    if(len > 7) {
        int qwords = len / 8;
        uint64_t tmp1 = seed1;
        for(int i = 0; i < qwords; i++) {
            uint8_t *pos = buf + (size_t)i * 8;
            uint64_t word;

            memcpy(&word, pos, sizeof(word));
            seed1 = seed2;
            tmp1 ^= (tmp1 << 23);
            tmp1  = ((seed1 >> 9 ^ tmp1) >> 17) ^ tmp1 ^ seed1;
            /* Feedback uses the plain word, before it is scrambled */
            seed2 = tmp1 ^ word;
            word ^= tmp1 + seed1;
            memcpy(pos, &word, sizeof(word));
            tmp1 = seed1;
        }

        buf  = buf + qwords * 8;
        len -= qwords * 8;
    }

    /* Tail of 1..7 bytes: handled through a zero-padded 64-bit temporary */
    if(len > 0) {
        uint64_t remainder = 0;
        memcpy(&remainder, buf, len);
        seed1 ^= seed1 << 23;
        seed1 = ((seed2 >> 9 ^ seed1) >> 17) ^ seed1 ^ seed2;
        remainder ^= seed1 + seed2;
        memcpy(buf, &remainder, len);
    }
}

/*
 * Undo ndi_scramble_type1() on len bytes of buf, in place.
 *
 * Mirrors the scrambler: each 64-bit word is decoded first and the
 * decoded (plain) value is then fed back into the keystream state, which
 * is the same value the scrambler used for its feedback.
 *
 * Words are loaded and stored with memcpy() rather than through a
 * uint64_t pointer: callers pass data + 12, which is not guaranteed to be
 * 8-byte aligned, so a direct 64-bit access could be misaligned and would
 * break the aliasing rules.
 *
 * buf  - data to unscramble (modified in place)
 * len  - number of bytes in buf; values <= 0 leave buf untouched
 * seed - the 32-bit seed that was used for scrambling
 */
static void ndi_unscramble_type1(uint8_t *buf, int len, uint32_t seed)
{
    uint64_t seed64 = ((uint64_t)seed << 32) | seed;
    uint64_t seed1  = seed64 ^ 0xb711674bd24f4b24ULL;
    uint64_t seed2  = seed64 ^ 0xb080d84f1fe3bf44ULL;

    if(len > 7) {
        int qwords = len / 8;
        uint64_t tmp1 = seed1;
        for(int i = 0; i < qwords; i++) {
            uint8_t *pos = buf + (size_t)i * 8;
            uint64_t word;

            memcpy(&word, pos, sizeof(word));
            seed1 = seed2;
            tmp1 ^= (tmp1 << 23);
            tmp1  = ((seed1 >> 9 ^ tmp1) >> 17) ^ tmp1 ^ seed1;
            word ^= tmp1 + seed1;
            /* Feedback uses the decoded word */
            seed2 = tmp1 ^ word;
            memcpy(pos, &word, sizeof(word));
            tmp1 = seed1;
        }

        buf  = buf + qwords * 8;
        len -= qwords * 8;
    }

    /* Tail of 1..7 bytes: handled through a zero-padded 64-bit temporary */
    if(len > 0) {
        uint64_t remainder = 0;
        memcpy(&remainder, buf, len);
        seed1 ^= seed1 << 23;
        seed1 = ((seed2 >> 9 ^ seed1) >> 17) ^ seed1 ^ seed2;
        remainder ^= seed1 + seed2;
        memcpy(buf, &remainder, len);
    }
}

/*
 * Undo "type 2" scrambling on len bytes of buf, in place.
 *
 * Two passes are applied:
 *  1. every complete 64-bit word is transformed with a multiply/xor mix
 *     keyed by the running seed and the total length; the low 32 bits of
 *     the still-scrambled word become the seed for the next word;
 *  2. the first min(len, 128) bytes are XORed with ndi_xortab.
 *
 * All arithmetic is done on uint64_t so that the multiplications wrap
 * modulo 2^64 by definition (the signed form overflows, which is
 * undefined).  Words are accessed with memcpy() because callers pass
 * data + 12, which is not guaranteed to be 8-byte aligned.
 *
 * buf  - data to unscramble (modified in place)
 * len  - number of bytes in buf
 * seed - initial 32-bit seed
 */
static void ndi_unscramble_type2(uint8_t *buf, int len, uint32_t seed)
{
    int xor_len = 128;

    if(len >= 8) {
        const uint64_t mul1 = (uint64_t)-0x61c8864680b583ebLL;
        const uint64_t add1 = 0xc42bd7dee6270f1bULL;
        const uint64_t mul2 = (uint64_t)-0xe217c1e66c88cc3LL;
        const uint64_t add2 = 0x2daa8c593b1b4591ULL;
        int len8 = len >> 3;

        for(int i = 0; i < len8; i++) {
            uint8_t *pos = buf + (size_t)i * 8;
            uint64_t word;
            uint64_t prev_seed = seed;

            memcpy(&word, pos, sizeof(word));
            /* Next word is keyed by the low half of this scrambled word */
            seed = (uint32_t)(word & 0xffffffff);
            word = ((prev_seed * (uint64_t)len * mul1 + add1) ^ word) * mul2 + add2;
            memcpy(pos, &word, sizeof(word));
        }
    }

    if(len < xor_len)
        xor_len = len;

    for(int i = 0; i < xor_len; i++)
        buf[i] ^= ndi_xortab[i];
}

/*
 * Build an unscrambled text message in dst.
 *
 * Layout written (all multi-byte fields little-endian):
 *   bytes  0..1   version 1 with the top (scrambled) bit set
 *   bytes  2..3   message type NDI_DATA_TEXT
 *   bytes  4..7   header length (8)
 *   bytes  8..11  payload length
 *   bytes 12..19  8 zero bytes (the header)
 *   bytes 20..    payload
 *
 * dst         - output buffer, must hold at least 20 + payload_len bytes
 * dst_len     - receives the total number of bytes written
 * payload     - payload bytes to copy
 * payload_len - number of payload bytes
 */
static void create_text_message(uint8_t *dst, int *dst_len, char *payload, int payload_len)
{
    uint32_t plen = (uint32_t)payload_len;

    /* Version/Scrambling type */
    dst[0] = 0x01;
    dst[1] = 0x80;

    /* Message Type */
    dst[2] = NDI_DATA_TEXT;
    dst[3] = 0;

    /* Header Length */
    dst[4] = 8;
    dst[5] = 0;
    dst[6] = 0;
    dst[7] = 0;

    /* Payload Length: all four bytes, so payloads above 255 bytes are
     * described correctly */
    dst[8]  = plen & 0xff;
    dst[9]  = (plen >> 8) & 0xff;
    dst[10] = (plen >> 16) & 0xff;
    dst[11] = (plen >> 24) & 0xff;

    /* 8 bytes of zero */
    memset(&dst[12], 0, 8);

    /* Copy payload */
    memcpy(&dst[20], payload, payload_len);

    *dst_len = 20 + payload_len;
}

/*
 * Parse the header of an unscrambled video message and pass the frame to
 * the user callback.
 *
 * ndi_ctx     - receiver context providing the callback and its user data
 * data        - start of the message header; the payload follows at
 *               data + header_len
 * header_len  - header size in bytes
 * payload_len - payload size in bytes
 *
 * Returns 0 after the callback has been invoked, or -1 when the header is
 * too short to contain the fields parsed here.
 */
static int process_video_message(ndi_ctx *ndi_ctx, uint8_t *data, int header_len, int payload_len)
{
    ndi_data ndi_data = {0};

    /* The fields read below occupy header bytes 0..19 */
    if(header_len < 20)
        return -1;

    /* Little-endian 32-bit fields; cast before shifting so the top byte
     * never shifts into the sign bit of an int */
    uint32_t fourcc  = (uint32_t)data[0]  | ((uint32_t)data[1]  << 8) | ((uint32_t)data[2]  << 16) | ((uint32_t)data[3]  << 24);
    uint32_t width   = (uint32_t)data[4]  | ((uint32_t)data[5]  << 8) | ((uint32_t)data[6]  << 16) | ((uint32_t)data[7]  << 24);
    uint32_t height  = (uint32_t)data[8]  | ((uint32_t)data[9]  << 8) | ((uint32_t)data[10] << 16) | ((uint32_t)data[11] << 24);
    uint32_t fps_num = (uint32_t)data[12] | ((uint32_t)data[13] << 8) | ((uint32_t)data[14] << 16) | ((uint32_t)data[15] << 24);
    uint32_t fps_den = (uint32_t)data[16] | ((uint32_t)data[17] << 8) | ((uint32_t)data[18] << 16) | ((uint32_t)data[19] << 24);

    // XXX: some more things in the header

    ndi_data.data_type = NDI_DATA_VIDEO;
    ndi_data.data = data+header_len;
    ndi_data.len = payload_len;

    ndi_data.fourcc = fourcc;
    ndi_data.fps_num = fps_num;
    ndi_data.fps_den = fps_den;
    ndi_data.width = width;
    ndi_data.height = height;

    ndi_ctx->callback(&ndi_data, ndi_ctx->user_data);

    return 0;
}

/*
 * Parse an unscrambled audio message, convert it to one buffer per
 * channel and pass the result to the user callback.
 *
 * Two sample layouts are handled, selected by the header fourcc:
 *   'sowt' - interleaved 16-bit samples; each output sample is the two
 *            input bytes in swapped order.
 *   'fowt' - the payload starts with one 32-bit float scale factor per
 *            channel, followed by interleaved 16-bit samples for the
 *            channels whose scale factor is non-zero only; output samples
 *            are floats (sample * scale / 32767). Channels with a zero
 *            scale factor are output as zeros.
 *
 * ndi_ctx     - receiver context providing the callback and its user data
 * data        - start of the message header; the payload follows at
 *               data + header_len
 * header_len  - header size in bytes
 * payload_len - payload size in bytes; every payload read is checked
 *               against it
 *
 * Returns 0 after the callback has been invoked, -1 on a malformed or
 * unsupported message or on allocation failure. The per-channel buffers
 * are released before returning, so the callback must take its own
 * reference if it needs them afterwards.
 */
static int process_audio_message(ndi_ctx *ndi_ctx, uint8_t *data, int header_len, int payload_len)
{
    int ret = 0;
    ndi_data ndi_data = {0};
    float scale_factors[16];
    uint32_t num_nonzero_channels = 0;
    uint16_t bps = sizeof(int16_t);

    /* The fields read below occupy header bytes 0..15 */
    if(header_len < 16 || payload_len < 0)
        return -1;

    /* Little-endian 32-bit fields; cast before shifting so the top byte
     * never shifts into the sign bit of an int */
    uint32_t fourcc       = (uint32_t)data[0]  | ((uint32_t)data[1]  << 8) | ((uint32_t)data[2]  << 16) | ((uint32_t)data[3]  << 24);
    uint32_t samples      = (uint32_t)data[4]  | ((uint32_t)data[5]  << 8) | ((uint32_t)data[6]  << 16) | ((uint32_t)data[7]  << 24);
    uint32_t num_channels = (uint32_t)data[8]  | ((uint32_t)data[9]  << 8) | ((uint32_t)data[10] << 16) | ((uint32_t)data[11] << 24);
    uint32_t sample_rate  = (uint32_t)data[12] | ((uint32_t)data[13] << 8) | ((uint32_t)data[14] << 16) | ((uint32_t)data[15] << 24);

    /* num_channels indexes both scale_factors[] and ndi_data.buf[] */
    if(num_channels > sizeof(scale_factors) / sizeof(scale_factors[0]) ||
       num_channels > sizeof(ndi_data.buf) / sizeof(ndi_data.buf[0]))
        return -1;

    // XXX: some more things in the header
    data += header_len;
    uint64_t remaining = (uint64_t)payload_len;

    if(fourcc == MKTAG('f','o','w','t')) {
        if(remaining < (uint64_t)num_channels * sizeof(float))
            return -1;

        bps = sizeof(float);
        for(uint32_t i = 0; i < num_channels; i++) {
            uint32_t tmp = (uint32_t)data[0] | ((uint32_t)data[1] << 8) | ((uint32_t)data[2] << 16) | ((uint32_t)data[3] << 24);
            memcpy(&scale_factors[i], &tmp, sizeof(float));
            if(scale_factors[i] != 0.0f)
                num_nonzero_channels++;

            data += sizeof(float);
        }
        remaining -= (uint64_t)num_channels * sizeof(float);
    }
    else if(fourcc == MKTAG('s','o','w','t')) {
        for(uint32_t i = 0; i < num_channels; i++) {
            scale_factors[i] = 1.0f;
        }
        num_nonzero_channels = num_channels;
    }
    else {
        /* Unknown layout: scale_factors would be read uninitialised */
        return -1;
    }

    /* The conversion loop consumes one 16-bit sample per (sample, non-zero
     * channel) pair; make sure the payload really contains that much */
    if((uint64_t)samples * num_nonzero_channels * sizeof(int16_t) > remaining)
        return -1;

    /* NOTE(review): each per-channel buffer is sized for all channels, as
     * before, although only samples * bps bytes of it are written */
    uint64_t alloc_size = (uint64_t)num_channels * samples * bps;
    if(alloc_size > INT_MAX)
        return -1;

    ndi_data.data_type = NDI_DATA_AUDIO;
    ndi_data.fourcc = fourcc;
    ndi_data.samples = samples;
    ndi_data.num_channels = num_channels;
    ndi_data.sample_rate = sample_rate;

    for(uint32_t i = 0; i < num_channels; i++) {
        ndi_data.buf[i] = av_buffer_alloc((int)alloc_size);
        if(!ndi_data.buf[i]) {
            ret = -1;
            goto end;
        }
    }

    for(uint32_t j = 0; j < samples; j++) {
        for(uint32_t i = 0; i < num_channels; i++) {
            if(scale_factors[i] == 0.0f) {
                /* Silent channel: no input is consumed for it */
                memset(&ndi_data.buf[i]->data[(size_t)bps*j], 0, bps);
            }
            else {
                if(bps == 2) {
                    ndi_data.buf[i]->data[2*(size_t)j+0] = data[1];
                    ndi_data.buf[i]->data[2*(size_t)j+1] = data[0];
                } else if(bps == 4) {
                    float sf = scale_factors[i] / 32767.0f;
                    int16_t sample = ((uint16_t)data[1] << 8) | data[0];
                    sf *= sample;
                    memcpy(&ndi_data.buf[i]->data[4*(size_t)j], &sf, sizeof(sf));
                }
                data += sizeof(int16_t);
            }
        }
    }

    ndi_ctx->callback(&ndi_data, ndi_ctx->user_data);

end:
    for(uint32_t i = 0; i < num_channels; i++)
        av_buffer_unref(&ndi_data.buf[i]);

    return ret;
}

/*
 * Self-check for the type 1 scrambler: scramble and then unscramble a
 * random 23-byte buffer with a random seed, and report on stderr if the
 * result differs from the original bytes.
 */
static void test_scramblev1(void)
{
    uint8_t work[23];
    uint8_t reference[23];

    srand( time(NULL) );
    uint32_t seed = rand();

    /* Random payload; keep an untouched copy for the comparison */
    size_t pos = 0;
    while(pos < sizeof(work)) {
        work[pos] = rand();
        pos++;
    }
    memcpy(reference, work, sizeof(work));

    /* Round trip */
    ndi_scramble_type1(work, sizeof(work), seed);
    ndi_unscramble_type1(work, sizeof(work), seed);

    if(memcmp(work, reference, sizeof(work)) != 0)
        fprintf(stderr, "scrambling mismatch \n");
}

/*
 * Unscramble one complete packet and dispatch it by message type.
 *
 * Packet layout as parsed here (little-endian fields):
 *   bytes 0..1   version, with the top bit flagging a scrambled packet
 *   bytes 2..3   message type
 *   bytes 4..7   header size
 *   bytes 8..11  payload size
 *   bytes 12..   header followed by payload
 *
 * For text messages header and payload are unscrambled together; for
 * other types only the header is unscrambled. The scrambling scheme is
 * chosen from the version (with the scrambled bit masked off) and the
 * message type. Video and audio messages are forwarded to their handlers;
 * other message types are unscrambled and then ignored.
 *
 * ndi_ctx - receiver context
 * data    - packet bytes (modified in place)
 * len     - number of valid bytes in data; packets whose declared sizes
 *           do not fit in len are dropped
 */
static void process_ndi_packet(ndi_ctx *ndi_ctx, uint8_t *data, int len)
{
    if(len < 12) {
        fprintf(stderr, "NDI packet too short \n");
        return;
    }

    /* MSB = scrambled bit; the remaining bits carry the version that the
     * comparisons below are about */
    uint16_t header_type = (data[0] | (data[1] << 8)) & 0x7fff;

    uint16_t message_type = data[2] | (data[3] << 8);
    uint8_t scrambling_type = 1;

    uint32_t header_size = (uint32_t)data[4] | ((uint32_t)data[5] << 8) | ((uint32_t)data[6] << 16) | ((uint32_t)data[7] << 24);
    uint32_t payload_len = (uint32_t)data[8] | ((uint32_t)data[9] << 8) | ((uint32_t)data[10] << 16) | ((uint32_t)data[11] << 24);

    /* Everything after the 12 fixed bytes must fit in the buffer. Once this
     * holds, both sizes and their sum also fit in an int. */
    uint64_t body_len = (uint64_t)header_size + payload_len;
    if(body_len > (uint64_t)(len - 12)) {
        fprintf(stderr, "NDI packet sizes exceed received data \n");
        return;
    }

    uint32_t seed = header_size + payload_len;

    if(message_type == NDI_DATA_VIDEO && header_type > 3)
        scrambling_type = 2;
    else if(message_type == NDI_DATA_AUDIO && header_type > 2)
        scrambling_type = 2;
    else if(message_type == NDI_DATA_TEXT && header_type > 2)
        scrambling_type = 2;

    /* Text: header and payload are scrambled; otherwise the header only */
    int scrambled_len = (message_type == NDI_DATA_TEXT) ? (int)body_len : (int)header_size;

    if(scrambling_type == 1)
        ndi_unscramble_type1(data+12, scrambled_len, seed);
    else
        ndi_unscramble_type2(data+12, scrambled_len, seed);

    if(message_type == NDI_DATA_VIDEO)
        process_video_message(ndi_ctx, data+12, header_size, payload_len);
    else if(message_type == NDI_DATA_AUDIO)
        process_audio_message(ndi_ctx, data+12, header_size, payload_len);
}

/*
 * Extract and process every complete packet currently held in the FIFO.
 *
 * For each packet the 12-byte fixed header is peeked to learn the total
 * size (12 + header size + payload size), which is remembered in
 * ndi_ctx->target_size until that many bytes are available. The packet is
 * then copied out of the FIFO and passed to process_ndi_packet().
 *
 * Returns 0 when the FIFO holds no further complete packet, -1 on
 * allocation failure or when the buffered data does not start with a
 * usable packet header. In the latter case the buffered data is
 * discarded, since the framing can no longer be followed.
 */
static int handle_ndi_packet(ndi_ctx *ndi_ctx)
{
    for(;;) {
        if(!ndi_ctx->target_size) {
            uint8_t data[12];

            /* Wait until the whole fixed header has arrived */
            if(av_fifo_size(ndi_ctx->fifo) < (int)sizeof(data))
                return 0;

            av_fifo_generic_peek(ndi_ctx->fifo, data, sizeof(data), NULL);

            /* MSB = scrambled bit */
            uint16_t header_type = data[0] | (data[1] << 8);

            if(!(header_type >> 15)) {
                /* Only scrambled packets are understood by this parser */
                fprintf(stderr, "unexpected NDI packet header \n");
                av_fifo_reset(ndi_ctx->fifo);
                return -1;
            }

            uint32_t header_size = (uint32_t)data[4] | ((uint32_t)data[5] << 8) | ((uint32_t)data[6] << 16) | ((uint32_t)data[7] << 24);
            uint32_t payload_len = (uint32_t)data[8] | ((uint32_t)data[9] << 8) | ((uint32_t)data[10] << 16) | ((uint32_t)data[11] << 24);

            /* Keep target_size * 3 / 2 below within int range */
            uint64_t total = 12 + (uint64_t)header_size + payload_len;
            if(total > INT_MAX / 3) {
                fprintf(stderr, "NDI packet too large \n");
                av_fifo_reset(ndi_ctx->fifo);
                return -1;
            }

            ndi_ctx->target_size = (int)total;

            if(av_fifo_space(ndi_ctx->fifo) < ndi_ctx->target_size &&
               av_fifo_grow(ndi_ctx->fifo, ndi_ctx->target_size * 3 / 2) < 0)
                return -1;
        }

        /* Packet not complete yet: try again after the next receive */
        if(av_fifo_size(ndi_ctx->fifo) < ndi_ctx->target_size)
            return 0;

        /* FIXME: make this zero copy */
        uint8_t *data = malloc(ndi_ctx->target_size);
        if(!data)
            return -1;

        av_fifo_generic_read(ndi_ctx->fifo, data, ndi_ctx->target_size, NULL);
        process_ndi_packet(ndi_ctx, data, ndi_ctx->target_size);
        free(data);
        ndi_ctx->target_size = 0;
    }
}

/*
 * Read whatever is available on the socket (up to 5000 bytes), append it
 * to the FIFO and process any packets that became complete.
 *
 * Returns 0 on success, -1 when recv() fails, when the peer has closed
 * the connection, when the FIFO cannot be enlarged, or when packet
 * handling fails.
 */
static int receive_ndi_packet(ndi_ctx *ndi_ctx)
{
    uint8_t tmp[5000];

    /* TODO: Zero copy */
    int len = recv(ndi_ctx->socket_fd, (void *)tmp, sizeof(tmp), 0);
    if(len < 0) {
        printf("bad \n");
        return -1;
    }

    /* recv() returning 0 means the peer performed an orderly shutdown */
    if(len == 0) {
        printf("end \n");
        return -1;
    }

    if(av_fifo_space(ndi_ctx->fifo) < len &&
       av_fifo_grow(ndi_ctx->fifo, sizeof(tmp)) < 0) { // fixme
        fprintf(stderr, "Malloc failed \n");
        return -1;
    }

    av_fifo_generic_write(ndi_ctx->fifo, tmp, len, NULL);

    if(handle_ndi_packet(ndi_ctx) < 0) {
        printf("handle fail \n");
        return -1;
    }

    return 0;
}

/*
 * Send all queued request messages on the socket and release them.
 *
 * send() may accept fewer bytes than requested, so each message is sent
 * in a loop until it is complete. After the first failure nothing more is
 * sent, but the queue is still emptied so that ndi_request[] and
 * pending_requests stay consistent.
 *
 * Returns 0 when every message was sent completely, -1 otherwise.
 */
static int request_ndi_data(ndi_ctx *ndi_ctx)
{
    int pending = ndi_ctx->pending_requests;
    int ret = 0;

    for(int i = 0; i < pending; i++) {
        ndi_message *ndi_request = &ndi_ctx->ndi_request[i];
        int sent = 0;

        while(ret == 0 && sent < ndi_request->len) {
            int n = (int)send(ndi_ctx->socket_fd, (void *)(ndi_request->buf + sent),
                              ndi_request->len - sent, 0);
            if(n <= 0) {
                fprintf(stderr, "send failed \n");
                ret = -1;
            }
            else
                sent += n;
        }

        free(ndi_request->buf);
        ndi_request->buf = NULL;
        ndi_request->len = 0;
        ndi_ctx->pending_requests--;
    }

    return ret;
}

/*
 * Allocate a zero-initialised receiver context together with its 10000
 * byte receive FIFO.
 *
 * Returns the new context, or NULL if either allocation fails. The
 * context is released with libndi_close().
 */
ndi_ctx *libndi_init(void)
{
    ndi_ctx *ctx = calloc(1, sizeof(*ctx));

    if(ctx == NULL) {
        fprintf(stderr, "malloc failed \n");
        return NULL;
    }

    ctx->fifo = av_fifo_alloc(10000);
    if(ctx->fifo != NULL)
        return ctx;

    /* FIFO allocation failed: release the partially built context */
    libndi_close(ctx);
    return NULL;
}

/*
 * Store the connection options and queue the four initial request
 * messages (version, video quality, enabled streams, tally state).
 *
 * Each message is built with create_text_message() and its header and
 * payload are scrambled with the type 1 scheme, seeded with their
 * combined length. The messages are sent later by request_ndi_data().
 *
 * May be called again on the same context: the previous ip/port copies
 * and any requests still queued from an earlier call are released first.
 *
 * Returns 0 on success, -1 if ip or port is missing, on allocation
 * failure, or if the tally message does not fit its buffer. On failure
 * no requests are left pending.
 */
int libndi_setup(ndi_ctx *ndi_ctx, ndi_opts *ndi_opts)
{
    const size_t num_requests = sizeof(ndi_ctx->ndi_request) / sizeof(*ndi_ctx->ndi_request);

    /* Drop requests left over from a previous setup so their buffers are
     * not overwritten and leaked below */
    for(size_t i = 0; i < num_requests; i++) {
        free(ndi_ctx->ndi_request[i].buf);
        ndi_ctx->ndi_request[i].buf = NULL;
        ndi_ctx->ndi_request[i].len = 0;
    }
    ndi_ctx->pending_requests = 0;

    free(ndi_ctx->ip);
    ndi_ctx->ip = NULL;

    free(ndi_ctx->port);
    ndi_ctx->port = NULL;

    if(!ndi_opts->ip || !ndi_opts->port) {
        fprintf(stderr, "IP or port not set \n");
        return -1;
    }

    ndi_ctx->ip = strdup(ndi_opts->ip);
    if(!ndi_ctx->ip) {
        fprintf(stderr, "Malloc failed \n");
        return -1;
    }

    ndi_ctx->port = strdup(ndi_opts->port);
    if(!ndi_ctx->port) {
        fprintf(stderr, "Malloc failed \n");
        return -1;
    }

    /* Create and scramble request messages */
    char *tx_msgs[4] = {
        "<ndi_version text=\"3\" video=\"4\" audio=\"3\" sdk=\"3.5.1\" platform=\"LINUX\"/>",
        "<ndi_video quality=\"high\"/>",
        "<ndi_enabled_streams video=\"true\" audio=\"true\" text=\"true\"/>",
    };

    char tally_msg[64];
    int ret = snprintf(tally_msg, sizeof(tally_msg),
        "<ndi_tally on_program=\"%s\" on_preview=\"%s\"/>",
        (ndi_opts->initial_tally_state == NDI_TALLY_LIVE) ? "true" : "false",
        (ndi_opts->initial_tally_state == NDI_TALLY_PREVIEW) ? "true" : "false");

    if (ret < 0 || ret >= (int)sizeof(tally_msg))
        return -1;

    tx_msgs[3] = tally_msg;

    for(size_t i = 0; i < sizeof(tx_msgs) / sizeof(*tx_msgs); i++) {
        /* The terminating NUL is part of the payload */
        size_t payload_len = strlen(tx_msgs[i]) + 1;
        ndi_message *ndi_request = &ndi_ctx->ndi_request[i];
        ndi_request->buf = calloc(1, payload_len + 20);
        if(!ndi_request->buf) {
            fprintf(stderr, "Malloc failed \n");
            return -1;
        }

        int dst_len = 0;
        create_text_message(ndi_request->buf, &dst_len, tx_msgs[i], payload_len);
        /* Scramble the 8-byte header plus payload that follow the 12 fixed bytes */
        ndi_scramble_type1(ndi_request->buf+12, 8+payload_len, 8+payload_len);
        ndi_request->len = dst_len;
    }

    ndi_ctx->pending_requests = 4;

    return 0;
}

#ifdef _WIN32
/*
 * Start Winsock requesting version hi.lo.
 * Returns 0 when the library started and reports version 2.2; otherwise
 * the library is cleaned up again (if it started) and -1 is returned.
 */
static int system_InitWSA(int hi, int lo)
{
    WSADATA wsa;

    if (WSAStartup(MAKEWORD(hi, lo), &wsa) != 0)
        return -1;

    if (LOBYTE(wsa.wVersion) == 2 && HIBYTE(wsa.wVersion) == 2)
        return 0;

    /* Winsock DLL is not usable */
    WSACleanup();
    return -1;
}

/* Initialise the socket layer: try Winsock 2.2, then 1.1; report on
 * stderr if neither attempt succeeds. */
static void system_Init(void)
{
    if (system_InitWSA(2, 2) == 0)
        return;

    if (system_InitWSA(1, 1) == 0)
        return;

    fputs("Error: cannot initialize Winsocks\n", stderr);
}
#else
/* No socket-layer initialisation is performed on this platform. */
static void system_Init(void)
{
}
#endif

/* Close a socket descriptor with the call appropriate for the platform. */
static void ndi_close_socket(int fd)
{
#ifdef _WIN32
    closesocket(fd);
#else
    close(fd);
#endif
}

/*
 * Connect to the configured ip/port, send the queued requests and feed
 * received data to the packet parser until the connection ends.
 *
 * Every address returned by getaddrinfo() is tried in turn until one
 * connects. The function then blocks in a poll() loop and returns when
 * poll() times out (10 seconds without activity), when sending or
 * receiving fails, or when the socket reports an error or hang-up. The
 * socket is closed before returning.
 *
 * ndi_ctx   - context previously configured with libndi_setup()
 * callback  - called for every decoded video/audio message
 * user_data - passed unchanged to callback
 */
void libndi_receive_data(ndi_ctx *ndi_ctx, ndi_data_cb callback, void *user_data)
{
    ndi_ctx->callback = callback;
    ndi_ctx->user_data = user_data;
    ndi_ctx->socket_fd = -1;

    /* connect to socket */
    int ret;
    struct addrinfo hints, *res, *p;
    memset(&hints, 0, sizeof hints);
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;

    system_Init();

    if ((ret = getaddrinfo(ndi_ctx->ip, ndi_ctx->port, &hints, &res)) != 0) {
        fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(ret));
        return;
    }

    /* Use the first address that accepts the connection */
    for(p = res; p != NULL; p = p->ai_next) {
        int fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
        if(fd < 0)
            continue;

        if(connect(fd, p->ai_addr, p->ai_addrlen) == 0) {
            ndi_ctx->socket_fd = fd;
            break;
        }

        ndi_close_socket(fd);
    }

    freeaddrinfo(res);

    if(ndi_ctx->socket_fd < 0) {
        printf("can't connect \n");
        return;
    }

    struct pollfd fds[1];
    fds[0].fd = ndi_ctx->socket_fd;
    fds[0].events = POLLIN | POLLOUT;
    for(;;) {
        fds[0].revents = 0;
        ret = poll(fds, 1, 10000);
        if(ret < 0 && errno == EINTR)
            continue;

        /* 0: timeout, < 0: poll failure */
        if(ret <= 0)
            break;

        if(fds[0].revents & POLLOUT) {
            if(request_ndi_data(ndi_ctx) < 0)
                break;

            /* Nothing left to send: stop waiting for writability */
            if(!ndi_ctx->pending_requests)
                fds[0].events = POLLIN;
        }

        if(fds[0].revents & POLLIN) {
            if(receive_ndi_packet(ndi_ctx) < 0)
                break;
        }
        else if(fds[0].revents & (POLLERR | POLLHUP | POLLNVAL))
            break;
    }

    ndi_close_socket(ndi_ctx->socket_fd);
    ndi_ctx->socket_fd = -1;
}

/*
 * Release a context created by libndi_init(): the ip/port copies, any
 * request buffers still queued, the FIFO and the context itself.
 * Passing NULL is allowed and does nothing.
 */
void libndi_close(ndi_ctx *ndi_ctx)
{
    if(ndi_ctx == NULL)
        return;

    /* free(NULL) is a no-op, so unset options need no special casing */
    free(ndi_ctx->ip);
    free(ndi_ctx->port);

    ndi_message *request = ndi_ctx->ndi_request;
    ndi_message *const last = request + 4;
    while(request != last) {
        free(request->buf);
        request++;
    }

    av_fifo_free(ndi_ctx->fifo);

    free(ndi_ctx);
}
